import torch


from Conversion_utils import *
from Conversion_utils_opt import *
from transformers import GPT2LMHeadModel, OPTForCausalLM
from transformers import HfArgumentParser
from transformers import AutoConfig
from config import Construct_config



def nested_children(m: torch.nn.Module) -> "torch.nn.Module | dict[str, object]":
    """Recursively map a module's named children into a nested dict.

    Args:
        m: The module to walk.

    Returns:
        ``m`` itself when it has no direct children (a leaf module).
        Otherwise, a dict mapping each child's name, as yielded by
        ``m.named_children()`` and in that order, to the result of
        recursing into that child.
    """
    children = dict(m.named_children())
    if not children:
        # Leaf module: hand back the module object itself rather than a dict.
        return m
    # Recurse into every direct child; nesting depth mirrors the module tree.
    return {name: nested_children(child) for name, child in children.items()}



def Construct_NASgpt(model_args, construction_args):
    """Build the NAS-constructed model alongside the pretrained HF model it wraps.

    Steps:
      1. Load the Hugging Face config named by ``model_args.config_name``
         (or ``model_args.model_name_or_path`` when that is falsy). Copy its
         ``activation_function`` and ``max_position_embeddings`` onto
         ``construction_args``; this mutates it in place.
      2. Build the construction config with ``Construct_config``.
      3. Load a pretrained ``GPT2LMHeadModel`` or ``OPTForCausalLM``. The class
         is chosen by whether ``'gpt'`` (checked first) or ``'opt'`` appears
         in ``model_args.model_name_or_path``.
      4. If ``model_args.construct_load_model`` is set, replace both configs
         with the ones stored in the checkpoint at
         ``model_args.construct_model_path``.
      5. Wrap the pretrained model's module tree in ``NASgpt2`` / ``NASopt``.
         If a checkpoint was loaded, load its state dict into the wrapper.
         If ``model_args.construct_save_model`` is set, save a checkpoint.

    Args:
        model_args: Object exposing ``config_name``, ``model_name_or_path``,
            ``cache_dir``, ``construct_load_model``, ``construct_save_model``
            and ``construct_model_path``.
        construction_args: Object passed to ``Construct_config``. Its
            ``activation_function`` and ``seq_length`` attributes are
            overwritten here.

    Returns:
        Tuple ``(constructed_model, model, model_config, config)``. When a
        checkpoint is loaded, ``model_config`` and ``config`` are the objects
        read from the checkpoint, not the ones ``model`` was built with.

    Raises:
        NotImplementedError: If the model name contains neither ``'gpt'``
            nor ``'opt'``.
    """
    model_name = model_args.model_name_or_path
    model_config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_name,
        cache_dir=model_args.cache_dir,
    )

    construction_args.activation_function = model_config.activation_function
    construction_args.seq_length = model_config.max_position_embeddings

    config = Construct_config(model_config, construction_args)

    # Pick the HF model class and its matching NAS wrapper together so the
    # two can never disagree. 'gpt' is tested before 'opt'.
    if 'gpt' in model_name:
        model_fn, nas_fn = GPT2LMHeadModel, NASgpt2
    elif 'opt' in model_name:
        model_fn, nas_fn = OPTForCausalLM, NASopt
    else:
        raise NotImplementedError(
            f"Unsupported model (expected 'gpt' or 'opt' in name): {model_name}"
        )

    model = model_fn.from_pretrained(
        model_name,
        config=model_config,
        cache_dir=model_args.cache_dir,
    )

    print("....Constructing the model....")

    checkpoint = None
    if model_args.construct_load_model:
        print("....Load constructed model....")
        # NOTE(review): torch.load unpickles arbitrary objects, and this
        # checkpoint stores config objects as well as tensors. Only load
        # trusted files. On newer torch releases the default
        # weights_only=True may reject such a checkpoint; confirm the torch
        # version in use.
        checkpoint = torch.load(model_args.construct_model_path)
        model_config = checkpoint['model_config']
        config = checkpoint['construction_config']

    # Walk the pretrained model's module tree once. It is both logged and
    # handed to the NAS wrapper.
    module_tree = nested_children(model)
    print(module_tree)
    constructed_model = nas_fn(config, model_config, module_tree)

    if checkpoint is not None:
        constructed_model.load_state_dict(checkpoint['model_state_dict'])

    if model_args.construct_save_model:
        print("....Store constructed model....")
        torch.save(
            {
                'model_state_dict': constructed_model.state_dict(),
                'model_config': model_config,
                'construction_config': config,
            },
            model_args.construct_model_path,
        )

    # NOTE: this counts every parameter, including any with
    # requires_grad=False, despite the "trainable" wording of the message.
    total_params = sum(p.numel() for p in constructed_model.parameters())
    print("Total trainable parameters in constructed model", total_params)

    return (constructed_model, model, model_config, config)
    
    
    
#if __name__ == "__main__":
#    constructed_gpt2, simulated_gpt2, model_config, nest_model = Construct_NASgpt()
    
    
                 
    
    
    
